{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_3_feedforward_snn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HzIQBw28NL8h"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "\n",
        "# snnTorch - A Feedforward Spiking Neural Network\n",
        "## Tutorial 3\n",
        "### By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_3_feedforward_snn.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H56eQUdm3IuT"
      },
      "source": [
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ep_Qv7kzNOz6"
      },
      "source": [
        "# Introduction\n",
        "In this tutorial, you will:\n",
        "* Learn how to simplify the leaky integrate-and-fire (LIF) neuron to make it deep learning-friendly\n",
        "* Implement a feedforward spiking neural network (SNN)\n",
        "\n",
        "Install the latest PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SPQITvDuNNJg"
      },
      "outputs": [],
      "source": [
        "!pip install snntorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "31NkMZnxBFZ4"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "import snntorch as snn\n",
        "from snntorch import spikeplot as splt\n",
        "from snntorch import spikegen\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2kBGXe5K_xWh"
      },
      "outputs": [],
      "source": [
        "#@title Plotting Settings\n",
        "def plot_cur_mem_spk(cur, mem, spk, thr_line=False, vline=False, title=False, ylim_max1=1.25, ylim_max2=1.25):\n",
        "  # Generate Plots\n",
        "  fig, ax = plt.subplots(3, figsize=(8,6), sharex=True, \n",
        "                        gridspec_kw = {'height_ratios': [1, 1, 0.4]})\n",
        "\n",
        "  # Plot input current\n",
        "  ax[0].plot(cur, c=\"tab:orange\")\n",
        "  ax[0].set_ylim([0, ylim_max1])\n",
        "  ax[0].set_xlim([0, 200])\n",
        "  ax[0].set_ylabel(\"Input Current ($I_{in}$)\")\n",
        "  if title:\n",
        "    ax[0].set_title(title)\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem)\n",
        "  ax[1].set_ylim([0, ylim_max2]) \n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "  if thr_line:\n",
        "    ax[1].axhline(y=thr_line, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  # Plot output spike using spikeplot\n",
        "  splt.raster(spk, ax[2], s=400, c=\"black\", marker=\"|\")\n",
        "  if vline:\n",
        "    ax[2].axvline(x=vline, ymin=0, ymax=6.75, alpha = 0.15, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  plt.ylabel(\"Output spikes\")\n",
        "  plt.yticks([]) \n",
        "\n",
        "  plt.show()\n",
        "\n",
        "def plot_snn_spikes(spk_in, spk1_rec, spk2_rec, title):\n",
        "  # Generate Plots\n",
        "  fig, ax = plt.subplots(3, figsize=(8,7), sharex=True, \n",
        "                        gridspec_kw = {'height_ratios': [1, 1, 0.4]})\n",
        "\n",
        "  # Plot input spikes\n",
        "  splt.raster(spk_in[:,0], ax[0], s=0.03, c=\"black\")\n",
        "  ax[0].set_ylabel(\"Input Spikes\")\n",
        "  ax[0].set_title(title)\n",
        "\n",
        "  # Plot hidden layer spikes\n",
        "  splt.raster(spk1_rec.reshape(num_steps, -1), ax[1], s = 0.05, c=\"black\")\n",
        "  ax[1].set_ylabel(\"Hidden Layer\")\n",
        "\n",
        "  # Plot output spikes\n",
        "  splt.raster(spk2_rec.reshape(num_steps, -1), ax[2], c=\"black\", marker=\"|\")\n",
        "  ax[2].set_ylabel(\"Output Spikes\")\n",
        "  ax[2].set_ylim([0, 10])\n",
        "\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fw0Zfve-vHGb"
      },
      "source": [
        "# 1. Simplifying the Leaky Integrate-and-Fire Neuron Model\n",
        "In the previous tutorial, we designed our own LIF neuron model. But it was quite complex, and added an array of hyperparameters to tune, including $R$, $C$, $\\Delta t$, $U_{\\rm thr}$, and the choice of reset mechanism. This is a lot to keep track of, and only grows more cumbersome when scaled up to a full-blown SNN. So let's make a few simplfications. \n",
        "\n",
        "\n",
        "## 1.1 The Decay Rate: $\\beta$\n",
        "\n",
        "In the previous tutorial, the Euler method was used to derive the following solution to the passive membrane model:\n",
        "\n",
        "$$U(t+\\Delta t) = (1-\\frac{\\Delta t}{\\tau})U(t) + \\frac{\\Delta t}{\\tau} I_{\\rm in}(t)R \\tag{1}$$\n",
        "\n",
        "Now assume there is no input current, $I_{\\rm in}(t)=0 A$:\n",
        "\n",
        "$$U(t+\\Delta t) = (1-\\frac{\\Delta t}{\\tau})U(t) \\tag{2}$$\n",
        "\n",
        "Let the ratio of subsequent values of $U$, i.e., $U(t+\\Delta t)/U(t)$ be the decay rate of the membrane potential, also known as the inverse time constant:\n",
        "\n",
        "$$U(t+\\Delta t) = \\beta U(t) \\tag{3}$$\n",
        "\n",
        "From $(1)$, this implies that:\n",
        "\n",
        "$$\\beta = (1-\\frac{\\Delta t}{\\tau}) \\tag{4}$$\n",
        "\n",
        "For reasonable accuracy, $\\Delta t << \\tau$.\n",
        "\n",
        "## 1.2 Weighted Input Current\n",
        "If we assume $t$ represents time-steps in a sequence rather than continuous time, then we can set $\\Delta t = 1$. To further reduce the number of hyperparameters, assume $R=1$. From $(4)$, these assumptions lead to:\n",
        "\n",
        "$$\\beta = (1-\\frac{1}{C}) \\implies (1-\\beta)I_{\\rm in} = \\frac{1}{\\tau}I_{\\rm in} \\tag{5}$$\n",
        "\n",
        "The input current is weighted by $(1-\\beta)$. \n",
        "By additionally assuming input current instantaneously contributes to the membrane potential:\n",
        "\n",
        "$$U[t+1] = \\beta U[t] + (1-\\beta)I_{\\rm in}[t+1] \\tag{6}$$\n",
        "\n",
        "Note that the discretization of time means we are assuming that each time bin $t$ is brief enough such that a neuron may only emit a maximum of one spike in this interval.\n",
        "\n",
        "In deep learning, the weighting factor of an input is often a learnable parameter. Taking a step away from the physically viable assumptions made thus far, we subsume the effect of $(1-\\beta)$ from $(6)$ into a learnable weight $W$, and replace $I_{\\rm in}[t]$ accordingly with an input $X[t]$:\n",
        "\n",
        "$$WX[t] = I_{\\rm in}[t] \\tag{7}$$\n",
        "\n",
        "This can be interpreted in the following way. $X[t]$ is an input voltage, or spike, and is scaled by the synaptic conductance of $W$ to generate a current injection to the neuron. This gives us the following result:\n",
        "\n",
        "$$U[t+1] = \\beta U[t] + WX[t+1] \\tag{8}$$\n",
        "\n",
        "\n",
        "In future simulations, the effects of $W$ and $\\beta$ are decoupled.\n",
        "$W$ is a learnable parameter that is updated independently of $\\beta$.\n",
        "\n",
        "\n",
        "## 1.3 Spiking and Reset\n",
        "We now introduce the spiking and reset mechanisms. Recall that if the membrane exceeds the threshold, then the neuron emits an output spike: \n",
        "\n",
        "$$S[t] = \\begin{cases} 1, &\\text{if}~U[t] > U_{\\rm thr} \\\\\n",
        "0, &\\text{otherwise}\\end{cases} \\tag{9}$$\n",
        "\n",
        "If a spike is triggered, the membrane potential should be reset. The *reset-by-subtraction* mechanism is modeled by:\n",
        "\n",
        "$$U[t+1] = \\underbrace{\\beta U[t]}_\\text{decay} + \\underbrace{WX[t+1]}_\\text{input} - \\underbrace{S[t]U_{\\rm thr}}_\\text{reset} \\tag{10}$$\n",
        "\n",
        "As $W$ is a learnable parameter, and $U_{\\rm thr}$ is often just set to $1$ (though can be tuned), this leaves the decay rate $\\beta$ as the only hyperparameter left to be specified. This completes the painful part of this tutorial. \n",
        "\n",
 Note: some implementations">
        "> Note: some implementations might make slightly different assumptions. E.g., $S[t] \\rightarrow S[t+1]$ in $(9)$, or $X[t+1] \\rightarrow X[t]$ in $(10)$. The above derivation is what is used in snnTorch as we find it maps intuitively to a recurrent neural network representation, without any change in performance."
      ]
    },
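    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a worked instance of Equations $(9)$ and $(10)$, take some illustrative values (chosen here for the arithmetic, not taken from the tutorial): $\\beta = 0.8$, $U_{\\rm thr} = 1$, $U[t] = 1.1$ and $WX[t+1] = 0.2$. Since $U[t] > U_{\\rm thr}$, Equation $(9)$ gives $S[t] = 1$, and Equation $(10)$ becomes:\n",
        "\n",
        "$$U[t+1] = \\underbrace{0.8 \\times 1.1}_\\text{decay} + \\underbrace{0.2}_\\text{input} - \\underbrace{1 \\times 1}_\\text{reset} = 0.08$$\n",
        "\n",
        "Had the membrane potential been $U[t] = 0.9$ instead, no spike would be emitted ($S[t] = 0$) and the update would reduce to $U[t+1] = 0.8 \\times 0.9 + 0.2 = 0.92$."
      ]
    },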
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LCIwUXt_P9zk"
      },
      "source": [
        "## 1.4 Code Implementation\n",
        "Implementing this neuron in Python looks like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hJySGu7R-ozz"
      },
      "outputs": [],
      "source": [
        "def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):\n",
        "  spk = (mem > threshold) # if membrane exceeds threshold, spk=1, else, 0\n",
        "  mem = beta * mem + w*x - spk*threshold\n",
        "  return spk, mem"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CEoatctmAPn7"
      },
      "source": [
        "To set $\\beta$, we have the option of either using Equation $(3)$ to define it, or hard-coding it directly. Here, we will use $(3)$ for the sake of a demonstration, but in future, it will just be hard-coded as we are more focused on something that works rather than biological precision.\n",
        "\n",
        "Equation $(3)$ tells us that $\\beta$ is the ratio of membrane potential across two subsequent time steps. Solve this using the continuous time-dependent form of the equation (assuming no current injection), which was derived in [Tutorial 2](https://snntorch.readthedocs.io/en/latest/tutorials/index.html):\n",
        "\n",
        "$$U(t) = U_0e^{-\\frac{t}{\\tau}}$$\n",
        "\n",
        "$U_0$ is the initial membrane potential at $t=0$. Assume the time-dependent equation is computed at discrete steps of $t, (t+\\Delta t), (t+2\\Delta t)~...~$, then we can find the ratio of membrane potential between subsequent steps using:\n",
        "\n",
        "$$\\beta = \\frac{U_0e^{-\\frac{t+\\Delta t}{\\tau}}}{U_0e^{-\\frac{t}{\\tau}}} = \\frac{U_0e^{-\\frac{t + 2\\Delta t}{\\tau}}}{U_0e^{-\\frac{t+\\Delta t}{\\tau}}} =~~...$$\n",
        "$$\\implies \\beta = e^{-\\frac{\\Delta t}{\\tau}} $$\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LwVa4nF7AU35"
      },
      "outputs": [],
      "source": [
        "# set neuronal parameters\n",
        "delta_t = torch.tensor(1e-3)\n",
        "tau = torch.tensor(5e-3)\n",
        "beta = torch.exp(-delta_t/tau)\n",
        "\n",
        "print(f\"The decay rate is: {beta:.3f}\")"
      ]
    },
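    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Euler-based Equation $(4)$ and the exponential form above give slightly different decay rates for the same $\\Delta t$ and $\\tau$. The following sketch compares the two using only the standard `math` module. The names `beta_exact` and `beta_euler` are introduced here purely for illustration, and the values of $\\Delta t$ and $\\tau$ are copied from the cell above.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "delta_t = 1e-3  # time step in seconds, same value as the cell above\n",
        "tau = 5e-3      # membrane time constant in seconds, same value as the cell above\n",
        "\n",
        "beta_exact = math.exp(-delta_t / tau)  # ratio of successive values of U(t)\n",
        "beta_euler = 1 - delta_t / tau         # forward-Euler approximation, Equation (4)\n",
        "\n",
        "print(f'exact: {beta_exact:.3f}, Euler: {beta_euler:.3f}')\n",
        "```\n",
        "\n",
        "With $\\Delta t/\\tau = 0.2$ the two values differ by roughly $0.02$, and the gap narrows as $\\Delta t$ becomes smaller relative to $\\tau$."
      ]
    },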
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vsMZUtDxBKmi"
      },
      "source": [
        "Run a quick simulation to check the neuron behaves correctly in response to a step voltage input:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B_sTGsM9_IHI"
      },
      "outputs": [],
      "source": [
        "num_steps = 200\n",
        "\n",
        "# initialize inputs/outputs + small step current input\n",
        "x = torch.cat((torch.zeros(10), torch.ones(190)*0.5), 0)\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec = []\n",
        "spk_rec = []\n",
        "\n",
        "# neuron parameters\n",
        "w = 0.4\n",
        "beta = 0.819\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(x*w, mem_rec, spk_rec, thr_line=1,ylim_max1=0.5,\n",
        "                 title=\"LIF Neuron Model With Weighted Step Voltage\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W0AIRdY8-T7K"
      },
      "source": [
        "# 2. `Leaky` Neuron Model in snnTorch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NrhHf-qFB_E4"
      },
      "source": [
        "This same thing can be achieved by instantiating `snn.Leaky` in a similar way to how we used `snn.Lapicque` in the previous tutorial, but with less hyperparameters:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "htrgwDFhCIRS"
      },
      "outputs": [],
      "source": [
        "lif1 = snn.Leaky(beta=0.8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xU2w3JIDCUXu"
      },
      "source": [
        "The neuron model is now stored in `lif1`. To use this neuron:\n",
        "\n",
        "**Inputs**\n",
        "* `cur_in`: each element of $W\\times X[t]$ is sequentially passed as an input\n",
        "* `mem`: the previous step membrane potential, $U[t-1]$, is also passed as input.\n",
        "\n",
        "**Outputs**\n",
        "* `spk_out`: output spike $S[t]$ ('1' if there is a spike; '0' if there is no spike)\n",
        "* `mem`: membrane potential $U[t]$ of the present step\n",
        "\n",
        "These all need to be of type `torch.Tensor`. Note that here, we assume the input current has already been weighted before passing into the `snn.Leaky` neuron. This will make more sense when we construct a network-scale model. Also, equation $(10)$ has been time-shifted back one step without loss of generality.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jYIZW8ZfCT2e"
      },
      "outputs": [],
      "source": [
        "# Small step current input\n",
        "w=0.21\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*w), 0)\n",
        "mem = torch.zeros(1)\n",
        "spk = torch.zeros(1)\n",
        "mem_rec = []\n",
        "spk_rec = []\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk, mem = lif1(cur_in[step], mem)\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, ylim_max1=0.5,\n",
        "                 title=\"snn.Leaky Neuron Model\")"
      ]
    },
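    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why does an input of $0.21$ make the neuron fire, while a much smaller input would not? Ignoring spiking, a constant input $I$ drives $U[t+1] = \\beta U[t] + I$ towards the steady state $I/(1-\\beta)$, so the threshold can only be crossed if $I/(1-\\beta) > U_{\\rm thr}$. The plain-Python sketch below checks this. `steady_state` and `steps_to_threshold` are hypothetical helpers written for this illustration, not snnTorch functions, and the step count refers to this simplified recurrence rather than to the exact step at which `snn.Leaky` reports a spike.\n",
        "\n",
        "```python\n",
        "def steady_state(cur, beta):\n",
        "    # Fixed point of U = beta * U + cur (no spiking, no reset)\n",
        "    return cur / (1 - beta)\n",
        "\n",
        "def steps_to_threshold(cur, beta, threshold=1.0, max_steps=200):\n",
        "    # Iterate U[t+1] = beta * U[t] + cur from U = 0 and return the number\n",
        "    # of updates needed for U to exceed the threshold (None if it never does)\n",
        "    mem = 0.0\n",
        "    for step in range(1, max_steps + 1):\n",
        "        mem = beta * mem + cur\n",
        "        if mem > threshold:\n",
        "            return step\n",
        "    return None\n",
        "\n",
        "print(steady_state(0.21, 0.8))        # about 1.05, just above the threshold of 1\n",
        "print(steps_to_threshold(0.21, 0.8))  # 14\n",
        "print(steps_to_threshold(0.10, 0.8))  # None, as the steady state is about 0.5\n",
        "```\n",
        "\n",
        "This is why the step input above was chosen slightly larger than $(1-\\beta)U_{\\rm thr} = 0.2$: anything at or below that value would leave the membrane potential sitting under the threshold indefinitely."
      ]
    },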
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ijOEZK-FG5dv"
      },
      "source": [
        "This model has the same optional input arguments of `reset_mechanism` and `threshold` as described for Lapicque's neuron model."
      ]
    },
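    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the effect of the reset mechanism concrete, the sketch below extends the hand-written neuron from Section 1.4 with two reset options. It is a plain-Python illustration, not the snnTorch implementation: the function `lif_step` and the option names `'subtract'` and `'zero'` are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "def lif_step(mem, cur, beta, threshold=1.0, reset_mechanism='subtract'):\n",
        "    # Spike if the previous membrane potential exceeded the threshold\n",
        "    spk = 1.0 if mem > threshold else 0.0\n",
        "    # Decay plus input, before any reset is applied\n",
        "    base = beta * mem + cur\n",
        "    if reset_mechanism == 'subtract':\n",
        "        mem = base - spk * threshold  # reset-by-subtraction, Equation (10)\n",
        "    elif reset_mechanism == 'zero':\n",
        "        mem = base * (1.0 - spk)      # force the membrane to 0 after a spike\n",
        "    else:\n",
        "        raise ValueError(f'unknown reset mechanism: {reset_mechanism}')\n",
        "    return spk, mem\n",
        "\n",
        "# Same starting point, two reset mechanisms\n",
        "spk_sub, mem_sub = lif_step(1.5, 0.2, beta=0.8, reset_mechanism='subtract')\n",
        "spk_zero, mem_zero = lif_step(1.5, 0.2, beta=0.8, reset_mechanism='zero')\n",
        "print(spk_sub, mem_sub)    # spike emitted, membrane close to 0.4\n",
        "print(spk_zero, mem_zero)  # spike emitted, membrane exactly 0.0\n",
        "```\n",
        "\n",
        "Reset-by-subtraction keeps whatever potential exceeded the threshold, whereas a reset to zero discards it, so the two choices can lead to different firing rates for the same input."
      ]
    },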
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D15rgPqq-IOd"
      },
      "source": [
        "# 3. A Feedforward Spiking Neural Network\n",
        "\n",
        "So far, we have only considered how a single neuron responds to input stimulus. snnTorch makes it straightforward to scale this up to a deep neural network. In this section, we will create a 3-layer fully-connected neural network of dimensions 784-1000-10. Compared to our simulations so far, each neuron will now integrate over many more incoming input spikes. \n",
        "\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_8_fcn.png?raw=true' width=\"600\">\n",
        "</center>\n",
        "\n",
        "PyTorch is used to form the connections between neurons, and snnTorch to create the neurons. First, initialize all layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hu4aEh8I-H4Z"
      },
      "outputs": [],
      "source": [
        "# layer parameters\n",
        "num_inputs = 784\n",
        "num_hidden = 1000\n",
        "num_outputs = 10\n",
        "beta = 0.99\n",
        "\n",
        "# initialize layers\n",
        "fc1 = nn.Linear(num_inputs, num_hidden)\n",
        "lif1 = snn.Leaky(beta=beta)\n",
        "fc2 = nn.Linear(num_hidden, num_outputs)\n",
        "lif2 = snn.Leaky(beta=beta)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "el2qvzr8HaXI"
      },
      "source": [
        "Next, initialize the hidden variables and outputs of each spiking neuron. \n",
        "As networks increase in size, this becomes more tedious. The static method `init_leaky()` can be used to take care of this. All neurons in snnTorch have their own initialization methods that follow this same syntax, e.g., `init_lapicque()`. The shape of the hidden states are automatically initialized based on the input data dimensions during the first forward pass. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RChZjgMdzyFZ"
      },
      "outputs": [],
      "source": [
        "# Initialize hidden states\n",
        "mem1 = lif1.init_leaky()\n",
        "mem2 = lif2.init_leaky()\n",
        "\n",
        "# record outputs\n",
        "mem2_rec = []\n",
        "spk1_rec = []\n",
        "spk2_rec = []"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hV20mQhH8Ir"
      },
      "source": [
        "Create an input spike train to pass to the network. There are 200 time steps to simulate across 784 input neurons, i.e., the input originally has dimensions of $200 \\times 784$. However, neural nets typically process data in minibatches. snnTorch uses time-first dimensionality: \n",
        "\n",
        "[$time \\times batch\\_size \\times feature\\_dimensions$]\n",
        "\n",
        "So 'unsqueeze' the input along `dim=1` to indicate 'one batch' of data. The dimensions of this input tensor must be 200 $\\times$ 1 $\\times$ 784:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y8NHscQrH_cI"
      },
      "outputs": [],
      "source": [
        "spk_in = spikegen.rate_conv(torch.rand((200, 784))).unsqueeze(1)\n",
        "print(f\"Dimensions of spk_in: {spk_in.size()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XuS4V_HI_kn"
      },
      "source": [
        "Now it's finally time to run a full simulation. \n",
        "An intuitive way to think about how PyTorch and snnTorch work together is that PyTorch routes the neurons together, and snnTorch loads the results into spiking neuron models. In terms of coding up a network, these spiking neurons can be treated like time-varying activation functions.\n",
        "\n",
        "Here is a sequential account of what's going on:\n",
        "\n",
        "*  The $i^{th}$ input from `spk_in` to the $j^{th}$ neuron is weighted by the parameters initialized in `nn.Linear`: $X_{i} \\times W_{ij}$\n",
        "* This generates the input current term from Equation $(10)$, contributing to $U[t+1]$ of the spiking neuron\n",
        "* If $U[t+1] > U_{\\rm thr}$, then a spike is triggered from this neuron\n",
        "* This spike is weighted by the second layer weight, and the above process is repeated for all inputs, weights, and neurons.\n",
        "* If there is no spike, then nothing is passed to the post-synaptic neuron.\n",
        "\n",
        "The only difference from our simulations thus far is that we are now scaling the input current with a weight generated by `nn.Linear`, rather than manually setting $W$ ourselves."
      ]
    },
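    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sequential account above can be traced by hand on a much smaller network. The following plain-Python sketch wires 2 inputs to 2 hidden neurons to 1 output neuron. Everything in it is an assumption made for illustration: `lif_step` mirrors the hand-written neuron from Section 1.4, `linear` is a bias-free stand-in for `nn.Linear`, and the weights are hand-picked rather than randomly initialized.\n",
        "\n",
        "```python\n",
        "def lif_step(mem, cur, beta, threshold=1.0):\n",
        "    # Hand-written neuron from Section 1.4 (reset-by-subtraction)\n",
        "    spk = 1.0 if mem > threshold else 0.0\n",
        "    mem = beta * mem + cur - spk * threshold\n",
        "    return spk, mem\n",
        "\n",
        "def linear(x, weights):\n",
        "    # cur[j] = sum_i x[i] * weights[i][j]\n",
        "    return [sum(x_i * row[j] for x_i, row in zip(x, weights))\n",
        "            for j in range(len(weights[0]))]\n",
        "\n",
        "def lif_layer(mems, curs, beta):\n",
        "    # Apply lif_step to every neuron in a layer\n",
        "    results = [lif_step(m, c, beta) for m, c in zip(mems, curs)]\n",
        "    return [s for s, _ in results], [m for _, m in results]\n",
        "\n",
        "beta_demo = 0.5\n",
        "w1_demo = [[1.5, 0.2], [0.0, 0.0]]  # 2 inputs -> 2 hidden neurons\n",
        "w2_demo = [[0.6], [0.6]]            # 2 hidden neurons -> 1 output neuron\n",
        "spk_in_demo = [[1.0, 0.0]] * 10     # input 0 spikes at every step, input 1 never\n",
        "\n",
        "mem1_demo, mem2_demo = [0.0, 0.0], [0.0]\n",
        "spk1_rec_demo, spk2_rec_demo = [], []\n",
        "\n",
        "for x in spk_in_demo:\n",
        "    cur1 = linear(x, w1_demo)                                # input spikes x first-layer weights\n",
        "    spk1, mem1_demo = lif_layer(mem1_demo, cur1, beta_demo)  # hidden layer update\n",
        "    cur2 = linear(spk1, w2_demo)                             # hidden spikes x second-layer weights\n",
        "    spk2, mem2_demo = lif_layer(mem2_demo, cur2, beta_demo)  # output layer update\n",
        "    spk1_rec_demo.append(spk1)\n",
        "    spk2_rec_demo.append(spk2)\n",
        "\n",
        "print([s[0] for s in spk1_rec_demo])  # hidden neuron 0\n",
        "print([s[1] for s in spk1_rec_demo])  # hidden neuron 1\n",
        "print([s[0] for s in spk2_rec_demo])  # output neuron\n",
        "```\n",
        "\n",
        "With these hand-picked weights, hidden neuron 0 fires on nine of the ten steps, hidden neuron 1 stays silent because its input never lifts it above the threshold, and the output neuron fires twice."
      ]
    },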
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JR6XAuL4I-NH"
      },
      "outputs": [],
      "source": [
        "# network simulation\n",
        "for step in range(num_steps):\n",
        "    cur1 = fc1(spk_in[step]) # post-synaptic current <-- spk_in x weight\n",
        "    spk1, mem1 = lif1(cur1, mem1) # mem[t+1] <--post-syn current + decayed membrane\n",
        "    cur2 = fc2(spk1)\n",
        "    spk2, mem2 = lif2(cur2, mem2)\n",
        "\n",
        "    mem2_rec.append(mem2)\n",
        "    spk1_rec.append(spk1)\n",
        "    spk2_rec.append(spk2)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem2_rec = torch.stack(mem2_rec)\n",
        "spk1_rec = torch.stack(spk1_rec)\n",
        "spk2_rec = torch.stack(spk2_rec)\n",
        "\n",
        "plot_snn_spikes(spk_in, spk1_rec, spk2_rec, \"Fully Connected Spiking Neural Network\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5o4s2HEBNBWc"
      },
      "source": [
        "> If you run into errors, then try re-initializing your networks and parameters.\n",
        "\n",
        "\n",
        "At this stage, the spikes don't have any real meaning. The inputs and weights are all randomly initialized, and no training has taken place. But the spikes  should appear to be propagating from the first layer through to the output. If you are not seeing any spikes, then you might have been unlucky in the weight initialization lottery - you might want to try re-running the last four code-blocks. \n",
        "\n",
        "`spikeplot.spike_count` can create a spike counter of the output layer. The following animation will take some time to generate. <br>\n",
        "\n",
        "> Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FXhPeou8NxS0"
      },
      "outputs": [],
      "source": [
        "from IPython.display import HTML\n",
        "\n",
        "fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))\n",
        "labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']\n",
        "spk2_rec = spk2_rec.squeeze(1).detach().cpu()\n",
        "\n",
        "# plt.rcParams['animation.ffmpeg_path'] = 'C:\\\\path\\\\to\\\\your\\\\ffmpeg.exe'\n",
        "\n",
        "#  Plot spike count histogram\n",
        "anim = splt.spike_count(spk2_rec, fig, ax, labels=labels, animate=True)\n",
        "HTML(anim.to_html5_video())\n",
        "# anim.save(\"spike_bar.mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ESmXh1MkN0Cl"
      },
      "source": [
        "`spikeplot.traces` lets you visualize the membrane potential traces. We will plot 9 out of 10 output neurons. \n",
        "Compare it to the animation and raster plot above to see if you can match the traces to the neuron. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "549a-aMvOnWc"
      },
      "outputs": [],
      "source": [
        "# plot membrane potential traces\n",
        "splt.traces(mem2_rec.squeeze(1), spk=spk2_rec.squeeze(1))\n",
        "fig = plt.gcf() \n",
        "fig.set_size_inches(8, 6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "erpkjgKCV8jW"
      },
      "source": [
        "It is fairly normal if some neurons are firing while others are completely dead. Again, none of these spikes have any real meaning until the weights have been trained."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BGCY7g7R0Usy"
      },
      "source": [
        "# Conclusion\n",
        "\n",
        "That covers how to simplify the leaky integrate-and-fire neuron model, and then using it to build a spiking neural network. In practice, we will almost always prefer to use `snn.Leaky` over `snn.Lapicque` for training networks, as there is a smaller hyperparameter search space. \n",
        "\n",
        "[Tutorial 4](https://snntorch.readthedocs.io/en/latest/tutorials/index.html) goes into detail with the 2nd-order `snn.Synaptic` and `snn.Alpha` models.\n",
        "This next tutorials is not necessary for training a network, so if you wish to go straight to deep learning with snnTorch, then skip ahead to [Tutorial 5](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).\n",
        "\n",
        "For reference, the documentation [can be found here](https://snntorch.readthedocs.io/en/latest/snntorch.html).\n",
        "\n",
        "If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OTRCyd0Xa-QK"
      },
      "source": [
        "## Further Reading\n",
        "* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)\n",
        "* [snnTorch documentation](https://snntorch.readthedocs.io/en/latest/snntorch.html) of the Lapicque, Leaky, Synaptic, and Alpha models\n",
        "* [*Neuronal Dynamics:\n",
        "From single neurons to networks and models of cognition*](https://neuronaldynamics.epfl.ch/index.html) by\n",
        "Wulfram Gerstner, Werner M. Kistler, Richard Naud and Liam Paninski.\n",
        "* [Theoretical Neuroscience: Computational and Mathematical Modeling of Neural Systems](https://mitpress.mit.edu/books/theoretical-neuroscience) by Laurence F. Abbott and Peter Dayan"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "include_colab_link": true,
      "name": "Copy of Untitled17.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
